package ip

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"golang.org/x/text/encoding/simplifiedchinese"
	"io"
	"io/ioutil"
	"log"
	"net"
	"os"
	"strconv"
	"strings"
)

const (
	// recordLength is the size in bytes of one index record: a 4-byte
	// little-endian start IP followed by a 3-byte offset of its data record.
	recordLength uint32 = 7

	// Flag bytes inside a data record announcing that the next 3 bytes are
	// a pointer to where the country (and possibly area) string is stored
	// instead of an inline string.
	redirectMode1 uint8 = 0x01
	redirectMode2 uint8 = 0x02
)

// database is an open handle on the IP database file.
//
// All reads go through Seek followed by Read on fp, so the methods share one
// file position; a database is therefore not safe for concurrent use.
type database struct {
	fp *os.File // open database file

	beginIpOffset uint32 // file offset of the first index record (from header)
	lastIpOffset  uint32 // file offset of the last index record (from header)
	totalIpNum    uint32 // zero-based index of the last record: (last-begin)/recordLength
}

// getDbFile returns the path of the IP database on disk.
//
// On first use the embedded asset "data/qqwry.dat" is extracted to "ip.db"
// in the OS temporary directory. Later calls reuse any existing file at that
// path without validating its contents, which is why the extraction must
// never leave an empty or truncated file behind.
//
// The function has no error return, so any failure to prepare the file
// terminates the process via log.Fatal.
func getDbFile() string {
	dir := os.TempDir()
	p := dir + string(os.PathSeparator) + "ip.db"
	if _, e := os.Stat(p); e == nil {
		return p
	}

	b, e := Asset("data/qqwry.dat")
	if e != nil {
		// Writing an empty file here would be accepted by the Stat check
		// above on every later run, so fail loudly instead.
		log.Fatalf("load embedded ip database: %v", e)
	}
	if e := writeDbFileAtomic(dir, p, b, 0644); e != nil {
		log.Fatalf("extract ip database to %s: %v", p, e)
	}
	return p
}

// writeDbFileAtomic writes data to a uniquely named temporary file in dir and
// renames it to path, so a crash or a concurrent extraction does not leave a
// partially written database at path (rename is atomic on POSIX filesystems).
func writeDbFileAtomic(dir, path string, data []byte, perm os.FileMode) error {
	tmp, err := ioutil.TempFile(dir, "ip.db.*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	// Clean up the temp file if anything below fails; after a successful
	// rename tmpName no longer exists and the Remove error is irrelevant.
	defer os.Remove(tmpName)

	if _, err := tmp.Write(data); err != nil {
		tmp.Close()
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}
	// TempFile creates the file with 0600; restore the intended permissions.
	if err := os.Chmod(tmpName, perm); err != nil {
		return err
	}
	return os.Rename(tmpName, path)
}

// getDB opens the IP database file and reads its index header.
//
// This is not a singleton: every call opens a fresh *os.File in db.fp, and
// getDB never closes it, so the caller is responsible for closing db.fp.
func getDB() (*database, error) {
	db := &database{}
	if e := db.initialize(); e != nil {
		return nil, e
	}
	return db, nil
}

// initialize opens the database file and reads its 8-byte header: two
// little-endian uint32 offsets, of the first and of the last index record.
// From these it derives totalIpNum, the zero-based index of the last record.
//
// If anything fails after the file has been opened, the file is closed again
// so the descriptor does not leak.
func (db *database) initialize() (e error) {
	db.fp, e = os.Open(getDbFile())
	if e != nil {
		return errors.New("Open database file failed: " + e.Error())
	}

	var header [8]byte
	if _, e = io.ReadFull(db.fp, header[:]); e != nil {
		_ = db.fp.Close() // best effort; the read error is the useful one
		return fmt.Errorf("read database header: %w", e)
	}
	db.beginIpOffset = binary.LittleEndian.Uint32(header[0:4])
	db.lastIpOffset = binary.LittleEndian.Uint32(header[4:8])

	// Guard the unsigned subtraction below against a corrupt header.
	if db.lastIpOffset < db.beginIpOffset {
		_ = db.fp.Close()
		return fmt.Errorf("corrupt database header: last index offset %d precedes first %d",
			db.lastIpOffset, db.beginIpOffset)
	}
	db.totalIpNum = (db.lastIpOffset - db.beginIpOffset) / recordLength

	return nil
}

// query looks up ipStr, a textual IPv4 address, in the database. It returns
// the matching index range (BeginIP..EndIP) together with its country and
// area strings, decoded from GBK to UTF-8.
//
// It returns an error if ipStr is not a valid IPv4 address or if no index
// record's range contains the address.
func (db *database) query(ipStr string) (Location, error) {
	v4 := net.ParseIP(ipStr).To4()
	if v4 == nil {
		return Location{}, fmt.Errorf("invalid IPv4 address %q", ipStr)
	}
	ip := binary.BigEndian.Uint32(v4)

	recordPos, ok := db.findIndexRecord(ip)
	if !ok {
		return Location{}, fmt.Errorf("no record found for %s", ipStr)
	}

	// Index record: 4-byte start IP, then 3-byte offset of the data record.
	db.forward(recordPos)
	beginIP := long2ip(db.readBytesAsUint32(4))
	offset := db.readBytesAsUint32(3)

	// Data record: 4-byte end IP, then the country/area section.
	db.forward(offset)
	endIP := long2ip(db.readBytesAsUint32(4))

	var country, area []byte
	switch flag := db.readByte(); flag {
	case redirectMode1:
		// A 3-byte pointer to where the country (and area) are stored.
		countryOffset := db.readBytesAsUint32(3)
		db.forward(countryOffset)
		switch flag2 := db.readByte(); flag2 {
		case redirectMode2:
			db.forward(db.readBytesAsUint32(3))
			country = db.getString()

			// The area follows the 1-byte flag and 3-byte pointer.
			db.forward(countryOffset + 4)
			area = db.getArea()
		default:
			country = append([]byte{flag2}, db.getString()...)
			area = db.getArea()
		}
	case redirectMode2:
		db.forward(db.readBytesAsUint32(3))
		country = db.getString()

		// The area follows the 4-byte end IP, 1-byte flag and 3-byte pointer.
		db.forward(offset + 8)
		area = db.getArea()
	default:
		// The flag byte is actually the first byte of an inline country.
		country = append([]byte{flag}, db.getString()...)
		area = db.getArea()
	}

	return Location{
		IP:      ipStr,
		BeginIP: beginIP,
		EndIP:   endIP,
		Country: toUTF8(country),
		Area:    toUTF8(area),
	}, nil
}

// findIndexRecord binary-searches the index for the record whose
// [start IP, end IP] range contains ip and returns that record's file offset.
// ok is false when no range contains ip.
func (db *database) findIndexRecord(ip uint32) (pos uint32, ok bool) {
	lo, hi := uint32(0), db.totalIpNum
	for lo <= hi {
		mid := lo + (hi-lo)/2
		recPos := db.beginIpOffset + mid*recordLength
		db.forward(recPos)
		if ip < db.readBytesAsUint32(4) {
			// mid-1 would wrap around to MaxUint32 on the unsigned index.
			if mid == 0 {
				return 0, false
			}
			hi = mid - 1
			continue
		}
		// Follow the 3-byte pointer to the data record; its first 4 bytes
		// are the end IP of this range.
		db.forward(db.readBytesAsUint32(3))
		if ip > db.readBytesAsUint32(4) {
			lo = mid + 1
			continue
		}
		return recPos, true
	}
	return 0, false
}

// getString reads bytes from the current file position up to the next NUL
// byte and returns them without the NUL (which is consumed). It returns nil
// when the very first byte read is NUL. readByte yields 0 on read errors, so
// an error also ends the string.
func (db *database) getString() []byte {
	var out []byte
	for {
		c := db.readByte()
		if c == 0 {
			return out
		}
		out = append(out, c)
	}
}

// getArea reads the area string at the current file position.
//
// The first byte is a flag:
//   - 0: no area; an empty, non-nil slice is returned.
//   - 1 or 2: the next 3 bytes are an offset where the NUL-terminated area
//     string is stored.
//   - anything else: the flag is the first byte of an inline string.
func (db *database) getArea() []byte {
	flag := db.readByte()
	if flag == 0 {
		return []byte{}
	}
	if flag == redirectMode1 || flag == redirectMode2 {
		db.forward(db.readBytesAsUint32(3))
		return db.getString()
	}
	rest := db.getString()
	return append([]byte{flag}, rest...)
}

// readByte reads one byte at the current file position. A read error is
// logged and not returned; in that case the byte is left at its zero value,
// so 0 is returned.
func (db *database) readByte() byte {
	var buf [1]byte
	if _, err := db.fp.Read(buf[:]); err != nil {
		log.Println("raed byte failed:", err)
	}
	return buf[0]
}

// readBytesAsUint32 reads n bytes at the current file position and decodes
// them as a little-endian unsigned integer.
//
// Any n from 0 to 4 is supported; callers in this file use 3 (file offsets)
// and 4 (IPv4 addresses). If n exceeds 4, all n bytes are consumed but only
// the first 4 contribute to the result.
//
// Read errors, including a short read at end of file, are logged. Bytes that
// could not be read are treated as zero.
func (db *database) readBytesAsUint32(n int) uint32 {
	buf := make([]byte, n)
	// io.ReadFull retries until n bytes arrive; a single Read may return fewer.
	if _, e := io.ReadFull(db.fp, buf); e != nil {
		log.Println("raed bytes failed:", e)
	}

	// Zero-pad so a 4-byte uint32 can be decoded for any n < 4.
	for len(buf) < 4 {
		buf = append(buf, 0)
	}

	var t uint32
	if e := binary.Read(bytes.NewReader(buf), binary.LittleEndian, &t); e != nil {
		log.Println("convert into uint32 failed:", e)
	}
	return t
}

// forward moves the file position to the absolute byte offset, measured from
// the start of the file. Despite the name it can also move backwards. A seek
// failure is logged together with its cause and otherwise ignored.
func (db *database) forward(offset uint32) {
	if _, e := db.fp.Seek(int64(offset), io.SeekStart); e != nil {
		log.Println("forward failed:", e)
	}
}

// ip2long converts a dotted-decimal IPv4 string such as "1.2.3.4" into a
// uint32 with the first octet in the most significant byte (0x01020304).
//
// Malformed input yields 0 instead of panicking or silently producing a
// wrong address. Malformed means anything other than exactly four
// dot-separated unsigned decimal octets, each in the range 0-255.
func ip2long(ip string) uint32 {
	parts := strings.Split(ip, ".")
	if len(parts) != 4 {
		return 0
	}
	var v uint32
	for _, p := range parts {
		// bitSize 8 rejects values above 255, which would otherwise carry
		// into the neighbouring octet; ParseUint also rejects signs.
		octet, err := strconv.ParseUint(p, 10, 8)
		if err != nil {
			return 0
		}
		v = v<<8 | uint32(octet)
	}
	return v
}

// long2ip formats a uint32 IPv4 address (most significant byte first) in
// dotted-decimal notation, e.g. 0x01020304 -> "1.2.3.4".
func long2ip(ip uint32) string {
	a, b, c, d := byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip)
	return fmt.Sprintf("%d.%d.%d.%d", a, b, c, d)
}

// toUTF8 decodes GBK-encoded bytes into a UTF-8 Go string. query passes the
// raw country and area bytes read from the database through this function.
// If the decoder reports an error, the empty string is returned.
func toUTF8(b []byte) string {
	// The parameter is named b rather than byte so the predeclared byte type
	// is not shadowed inside this function.
	decoded, err := simplifiedchinese.GBK.NewDecoder().Bytes(b)
	if err != nil {
		return ""
	}
	return string(decoded)
}

// FindByStr looks up the location of a textual IPv4 address such as
// "1.2.3.4".
//
// Each call opens the database file through getDB. The file is closed again
// before returning, so repeated lookups do not leak file descriptors.
func FindByStr(ip string) (loc Location, e error) {
	d, e := getDB()
	if e != nil {
		return loc, e
	}
	defer d.fp.Close()
	return d.query(ip)
}

// Find looks up the location of ip.
//
// ip must be an IPv4 address, given either in 4-byte form or as an
// IPv4-mapped IPv6 address. The lookup works on 32-bit addresses only, so
// any other value, including nil, yields an error instead of being passed on
// as an unparsable string.
func Find(ip net.IP) (loc Location, e error) {
	v4 := ip.To4()
	if v4 == nil {
		return loc, fmt.Errorf("not an IPv4 address: %v", ip)
	}
	return FindByStr(v4.String())
}
